#!/usr/bin/env python

"""
This script automates the parallelisation of Dsuite Dtrios / Dsuite DtriosCombine.
It was tested with Python 2.7 and 3.6 and uses only the standard Python
library, so it should run on most systems with a standard Python
installation.
"""

from __future__ import (print_function, unicode_literals, division)

import argparse
import logging
import os
import subprocess
import sys

try:
    from shlex import quote as shell_quote  # Python 3.3+
except ImportError:
    from pipes import quote as shell_quote  # Python 2.7



# Layout of every log record: padded level name, timestamp, message.
# Alternative layout that also shows the calling function:
#   '%(levelname)-8s %(asctime)s %(funcName)20s()  %(message)s'
_LOG_FORMAT = '%(levelname)-8s %(asctime)s  %(message)s'

logging.basicConfig(format=_LOG_FORMAT)

# Root logger; starts at DEBUG, main() later resets the level from the
# --logging_level command-line option.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)



def get_n_snps(vcf_fn):
    """
    Count the variant records in a VCF file.

    The file is streamed through the shell pipeline
    ``<cat | gzip -dc> FILE | grep -v "^#" | wc -l``, i.e. every line that
    does not start with ``#`` counts as one variant.

    Parameters
    ----------
    vcf_fn : str
        Path to the VCF file. Files whose extension is ``.gz`` or ``.bz``
        are read with ``gzip -dc``; everything else is read with ``cat``.
        The path is shell-quoted, so it may contain spaces.

    Returns
    -------
    int
        Number of lines that do not start with ``#``.

    Raises
    ------
    subprocess.CalledProcessError
        If the pipeline exits with a non-zero status.
    ValueError
        If the pipeline output cannot be parsed as an integer.
    """
    vcf_ext = os.path.splitext(vcf_fn)[-1]
    # NOTE(review): '.bz' is routed through gzip exactly as before; gzip does
    # not decode bzip2 data -- confirm which formats are really intended.
    catfun = 'gzip -dc' if vcf_ext in ['.gz', '.bz'] else 'cat'

    logging.info("Checking number of variants in {vcf_fn}".format(vcf_fn=vcf_fn))

    command = '{catfun} {vcf_fn} | grep -v "^#" | wc -l'.format(
        catfun=catfun, vcf_fn=shell_quote(vcf_fn))
    p = subprocess.Popen(command,
                         shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    n_snps, err = p.communicate()
    err = err.decode('utf-8', 'replace').strip()

    if p.returncode:
        logging.error(err)
        raise subprocess.CalledProcessError(p.returncode, command)
    if err:
        # The shell reports only the exit status of the last pipeline stage
        # (wc), so a failing cat/gzip/grep is visible only through stderr.
        logging.warning(err)

    n_snps = int(n_snps)
    logging.info("{n_snps} variants in vcf.".format(n_snps=n_snps))

    return n_snps
        
    
def dsuite_dtrios(vcf_fn, sets_fn, run_name, n_cores, tree_fn=None,JKnum=None,
                  JKwindow=None,
                  n_snps=None, dsuite_path='', environment_setup=''):
    """
    Run ``Dsuite Dtrios`` on consecutive regions of a VCF, one process each.

    The variants are split into regions of ``lines_per_core`` variants. One
    shell command per region is started right away; afterwards the function
    waits for the processes in start order and echoes their stdout.

    Parameters
    ----------
    vcf_fn, sets_fn : str
        File arguments handed to ``Dsuite Dtrios`` (shell-quoted).
    run_name : str or None
        Value for ``--run-name``; the option is omitted when None.
    n_cores : int
        Upper bound on the number of regions/processes; must be >= 1.
    tree_fn : str or None
        Tree file, passed as an absolute path via ``--tree`` when not None.
    JKnum, JKwindow : int or None
        Passed via ``--JKnum`` / ``--JKwindow`` when not None.
    n_snps : int or None
        Number of variants; determined with get_n_snps() when None.
    dsuite_path : str
        Directory prefix of the ``Dsuite`` executable ('' to use the bare
        name). Inserted into the shell command verbatim.
    environment_setup : str
        Shell text placed verbatim in front of every command.

    Returns
    -------
    tuple
        ``(lines_per_core, starts)`` where ``starts`` holds the 1-based
        first variant number of every region.

    Raises
    ------
    ValueError
        If ``n_cores`` is smaller than 1.
    subprocess.CalledProcessError
        As soon as one of the region processes is found to have failed.
    """
    if n_cores < 1:
        raise ValueError("n_cores must be at least 1, got {}".format(n_cores))

    if n_snps is None:
        n_snps = get_n_snps(vcf_fn)

    if dsuite_path and dsuite_path[-1] != '/':
        dsuite_path += '/'

    lines_per_core = n_snps // n_cores + 1

    logging.info("Parallelizing across {n_cores} cores. Regions of {lines_per_core} variants per core.".format(n_cores=n_cores, lines_per_core=lines_per_core))

    # Region starts are 1-based variant numbers. range() excludes its stop
    # value, so the stop must be n_snps + 1; otherwise the last variant is
    # skipped whenever it would open a region of its own.
    starts = range(1, n_snps + 1, lines_per_core)
    if not starts:
        logging.warning("No variants to process; no Dsuite Dtrios process is started.")

    if run_name is None:
        run_str = ''
    else:
        run_str = "--run-name {}".format(shell_quote('{}'.format(run_name)))

    if tree_fn is None:
        tree_str = ''
    else:
        tree_str = "--tree {}".format(shell_quote(os.path.abspath(tree_fn)))

    if JKnum is None:
        jknum_str = ""
    else:
        jknum_str = "--JKnum {}".format(JKnum)

    if JKwindow is None:
        jkwindow_str = ""
    else:
        jkwindow_str = "--JKwindow {}".format(JKwindow)

    quoted_vcf = shell_quote(vcf_fn)
    quoted_sets = shell_quote(sets_fn)

    processes = []
    commands = []
    for start in starts:
        # Every option placeholder is separated by a space so that two
        # adjacent options can never be fused into a single word.
        command_str = ('{environment_setup} {dsuite_path}Dsuite Dtrios {run_str} {tree_str} {jknum_str} '
                       '{jkwindow_str}  --region={start},{lines_per_core} {vcf_fn} {sets_fn}'.format(
                           environment_setup=environment_setup, dsuite_path=dsuite_path,
                           run_str=run_str, tree_str=tree_str,
                           jknum_str=jknum_str, jkwindow_str=jkwindow_str, start=start,
                           lines_per_core=lines_per_core, vcf_fn=quoted_vcf, sets_fn=quoted_sets))
        logging.info('Starting process: {command_str}'.format(command_str=command_str))
        p = subprocess.Popen(command_str,
                             shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        commands.append(command_str)
        processes.append(p)

    # All processes are already running; collect their results in order.
    for s, c, p in zip(starts, commands, processes):
        o, e = p.communicate()
        o = o.decode('utf-8')
        e = e.decode('utf-8')
        if o:
            print(o, file=sys.stdout)
        if p.returncode:
            logging.error(e)
            raise subprocess.CalledProcessError(p.returncode, c)
        if e:
            logging.debug('{e}'.format(e=e))
        logging.info('Successfully finished process for interval starting at {s}'.format(s=s))

    return lines_per_core, starts
    
def dsuite_combine(sets_fn, run_name, starts, lines_per_core, tree_fn=None,
                   remove_intermediate_files=True, dsuite_path='',  environment_setup=''):
    """
    Merge region-wise ``Dsuite Dtrios`` results with ``Dsuite DtriosCombine``.

    The combine command is run from the directory of ``sets_fn`` with
    ``--out-prefix out``. The per-region inputs are named
    ``<sets path without extension>_<run_name>_<start>_<start + lines_per_core>``.

    Parameters
    ----------
    sets_fn : str
        Path of the SETS file that was used for the Dtrios runs.
    run_name : str or None
        Run name of the Dtrios runs; None is treated as an empty name and
        ``--run-name`` is then omitted.
    starts : iterable of int
        First variant number of every region (as returned by dsuite_dtrios).
    lines_per_core : int
        Region length (as returned by dsuite_dtrios).
    tree_fn : str or None
        Tree file, passed as an absolute path via ``--tree`` when not None.
    remove_intermediate_files : bool
        Delete the per-region ``_combine.txt``, ``_combine_stderr.txt``,
        ``_BBAA.txt`` and ``_Dmin.txt`` files after combining.
    dsuite_path : str
        Directory prefix of the ``Dsuite`` executable ('' to use the bare
        name). Inserted into the shell command verbatim.
    environment_setup : str
        Shell text placed verbatim in front of the Dsuite command.

    Returns
    -------
    int
        0 on success; 1 if at least one intermediate file could not be
        removed.

    Raises
    ------
    subprocess.CalledProcessError
        If the combine command exits with a non-zero status.
    """
    extensions = ['_combine.txt', '_combine_stderr.txt', '_BBAA.txt', '_Dmin.txt']

    sets_dir = os.path.dirname(os.path.abspath(sets_fn))
    sets_base = os.path.splitext(os.path.abspath(sets_fn))[0]

    if dsuite_path and dsuite_path[-1] != '/':
        dsuite_path += '/'

    if run_name is None:
        run_str = ""
        run_name = ""
    else:
        run_str = "--run-name " + shell_quote(run_name)

    if tree_fn is None:
        tree_str = ''
    else:
        tree_str = "--tree {}".format(shell_quote(os.path.abspath(tree_fn)))

    outbases = ['{sets_base}_{run_name}_{s}_{e}'.format(sets_base=sets_base, run_name=run_name,s=s,
                                                        e=s + lines_per_core) for s in starts]

    combine_command = 'cd {sets_dir}; {environment_setup} {dsuite_path}Dsuite DtriosCombine --out-prefix out {run_str} {tree_str} '.format(
                        sets_dir=shell_quote(sets_dir),
                        environment_setup=environment_setup, dsuite_path=dsuite_path, run_str=run_str, tree_str=tree_str) \
                        + ' '.join(shell_quote(b) for b in outbases)

    logging.info('Combining output from {} runs with: '.format(len(starts)) + combine_command)

    p = subprocess.Popen(combine_command,
                         shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    o, e = p.communicate()
    o = o.decode('utf-8')
    e = e.decode('utf-8')
    if o:
        print(o, file=sys.stdout)
    if p.returncode:
        if e:
            logging.error(e)
        raise subprocess.CalledProcessError(p.returncode, combine_command)

    if e:
        logging.debug(e)
    # Report success regardless of whether the command wrote to stderr.
    logging.info('Successfully combined {} runs into outoput files'
                 ' with base {}.'.format(len(starts),
                                         os.path.join(sets_dir, "out_" +
                                                      run_name + '_combined')))

    returncode = p.returncode

    if remove_intermediate_files:
        logging.info('Removing intermediate files.')
        for b in outbases:
            for ex in extensions:
                intermediate_file = b + ex
                try:
                    os.remove(intermediate_file)
                except OSError as exc:
                    # Keep going so that the remaining files are still
                    # cleaned up; the failure is reported via the return code.
                    logging.error('Could not remove {}: {}'.format(intermediate_file, exc))
                    returncode = 1

    return returncode




def main():
    """
    Command-line entry point.

    Parses the arguments, locates Dsuite, counts the variants in the VCF,
    runs the region-wise ``Dsuite Dtrios`` processes and finally combines
    their results.

    Returns
    -------
    int
        The return code of dsuite_combine(), intended as process exit status.
    """
    parser = argparse.ArgumentParser(description=("This python script automates parallelisation of Dsuite Dtrios/ Dsuite DtriosCombine. "
                                                  "The usage is analogous to Dsuite Dtrios but computation is performed "
                                                  " on multiple cores (default: number of available CPUs). "
                                                  "It should run on most systems with a standard python installation "
                                                  "(tested with python 2.7 and 3.6)."))
    parser.add_argument("vcf_fn", metavar="INPUT_FILE.vcf")
    parser.add_argument("sets_fn", metavar="SETS.txt",
                        help=("The SETS.txt should have two columns: SAMPLE_ID    SPECIES_ID\n"
                        "The outgroup (can be multiple samples) should be specified by using the "
                        "keyword Outgroup in place of the SPECIES_ID"))
    parser.add_argument("--cores", type=int,
                        help=("(default=CPU count) Number of Dsuite Dtrios processes run in parallel."),
                        default=None)
    parser.add_argument("-k", "--JKnum", type=int,
                        help=("(default=20) the number of Jackknife blocks to divide the dataset into;"
                              " should be at least 20 for the whole dataset"),
                        default=None)
    parser.add_argument("-j", "--JKwindow", type=int,
                        help=("Jackknife block size in number of informative SNPs (as used in v0.2)"
                              " when specified, this is used in place of the --JKnum option"),
                        required=False)

    parser.add_argument("-t", "--tree", type=str,
                        help=("a file with a tree in the newick format specifying the relationships between populations/species"
                             " D and f4-ratio values for trios arranged according to the tree will be output in a file with _tree.txt suffix"),
                        required=False)

    parser.add_argument("-n", "--run-name", type=str,
                        help="run-name will be included in the output file name",
                        required=False)
    parser.add_argument( "--keep-intermediate", action='store_true',
                        help="Keep region-wise Dsuite Dtrios results.")
    parser.add_argument('--logging_level','-l',
                        choices=['DEBUG','INFO','WARNING','ERROR','CRITICAL'],
                        default='INFO',
                        help='Minimun level of logging.')
    parser.add_argument('--dsuite-path',type=str, required=False,
                        help="Explicitly set the path to the directory in which Dsuite is located. By default the script will first check"
                            " whether Dsuite is accessible from $PATH. "
                             " If not it will try to locate Dsuite at ../Build/Dsuite.")
    parser.add_argument('--environment-setup',type=str, required=False,
                        help="Command that should be run to setup the environment for Dsuite. E.g., 'module load GCC' or 'conda activate'")

    # Unknown arguments are tolerated and only reported with a warning.
    args, unknown = parser.parse_known_args()

    logger.setLevel(getattr(logging, args.logging_level))

    if unknown:
        logger.warning("The following unrecognized arguments are not used: {}".format(unknown))

    def which(program):
        """Return the path of an executable *program*, or None if not found."""
        def is_exe(fpath):
            return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

        fpath, fname = os.path.split(program)
        if fpath:
            if is_exe(program):
                return program
        else:
            # Fall back to the platform default search path when PATH is
            # not set instead of failing with a KeyError.
            for path in os.environ.get("PATH", os.defpath).split(os.pathsep):
                exe_file = os.path.join(path, program)
                if is_exe(exe_file):
                    return exe_file

        return None

    # Check whether Dsuite is accessible: prefer the one on PATH, otherwise
    # assume it lives in ../Build relative to this script.
    if args.dsuite_path is None:
        dsuite_path = which("Dsuite")
        if dsuite_path is not None:
            args.dsuite_path = ''
        else:
            args.dsuite_path = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'Build'))

    if args.environment_setup is None:
        args.environment_setup = ""
    else:
        # Terminate the setup command so the Dsuite command can follow it.
        args.environment_setup += ';'

    if args.cores is None:
        import multiprocessing
        try:
            args.cores = multiprocessing.cpu_count()
        except NotImplementedError:
            # The CPU count cannot be determined on this platform.
            args.cores = 1
    if args.cores < 1:
        parser.error("--cores must be a positive integer")

    # --JKwindow takes precedence over --JKnum.
    if args.JKwindow is not None:
        args.JKnum = None

    vcf_fn = os.path.abspath(args.vcf_fn)
    sets_fn = os.path.abspath(args.sets_fn)

    n_snps = get_n_snps(args.vcf_fn)

    lines_per_core, starts = dsuite_dtrios(vcf_fn, sets_fn, args.run_name, args.cores, tree_fn=args.tree,
                                           n_snps=n_snps, JKnum=args.JKnum, JKwindow=args.JKwindow,
                                           dsuite_path=args.dsuite_path,
                                           environment_setup=args.environment_setup)

    rc = dsuite_combine(sets_fn, args.run_name, starts, lines_per_core, tree_fn=args.tree,
                        remove_intermediate_files=not args.keep_intermediate,
                        dsuite_path=args.dsuite_path,
                        environment_setup=args.environment_setup)

    return rc


if __name__ == "__main__":
    # Use main()'s return value as the exit status of the process.
    raise SystemExit(main())

